#include "lenet.h"
#include "lenet_forward0.comp.gen.h"
#include "lenet_forward1.comp.gen.h"
#include "lenet_forward2.comp.gen.h"
#include "lenet_forward3.comp.gen.h"
#include "lenet_forward4.comp.gen.h"
#include "lenet_forward5.comp.gen.h"
#include "lenet_softmax.comp.gen.h"
#include "lenet_backward0.comp.gen.h"
#include "lenet_backward1.comp.gen.h"
#include "lenet_backward2.comp.gen.h"
#include "lenet_backward3.comp.gen.h"
#include "lenet_backward4.comp.gen.h"
#include "lenet_backward5.comp.gen.h"
#include "lenet_delta0.comp.gen.h"
#include "lenet_delta2.comp.gen.h"
#include "lenet_delta4.comp.gen.h"
#include "lenet_delta5.comp.gen.h"
#include "lenet_update.comp.gen.h"

/* Evaluate a VkResult-producing expression exactly once and return that value
 * from the enclosing function if it is not VK_SUCCESS. Wrapped in
 * do { } while (0) so `VkResultCheck(x);` is a single statement and is safe in
 * an unbraced if/else. */
#define VkResultCheck(x) do { VkResult vk_check_result_ = (x); if (VK_SUCCESS != vk_check_result_) return vk_check_result_; } while (0)

/* Element count of a true array object (not valid on pointers / array parameters). */
#define GETLENGTH(array) (sizeof(array)/sizeof(*(array)))

/* Number of floats that fit in the storage of `array` (an object or a parenthesised type name). */
#define GETCOUNT(array)  (sizeof(array)/sizeof(float))


/* Create a storage buffer of `size` bytes and bind it to freshly allocated memory
 * whose type has at least the property bits in `flag`.
 *
 * Requires ctx->device and ctx->physicalDeviceMemoryProperty to be initialised.
 * On success pBuffer->Buffer / pBuffer->Memory hold the new handles. On failure
 * both are VK_NULL_HANDLE and nothing is leaked. If no memory type satisfies
 * both the buffer requirements and `flag`, returns VK_ERROR_FEATURE_NOT_PRESENT. */
static VkResult CreateStorageBuffer(DeviceContext* ctx, StorageBuffer* pBuffer, const VkDeviceSize size, const VkMemoryPropertyFlags flag)
{
    pBuffer->Buffer = VK_NULL_HANDLE;
    pBuffer->Memory = VK_NULL_HANDLE;

    VkBufferCreateInfo bufferInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO };
    bufferInfo.usage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
    bufferInfo.size = size;
    bufferInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE;

    VkBuffer buffer = VK_NULL_HANDLE;
    VkResultCheck(vkCreateBuffer(ctx->device, &bufferInfo, NULL, &buffer));

    VkMemoryRequirements memRequirements;
    vkGetBufferMemoryRequirements(ctx->device, buffer, &memRequirements);

    /* Take the FIRST compatible memory type: the Vulkan spec orders memory types
     * so that earlier entries are the preferred ones for a given property set. */
    uint32_t memoryTypeIndex = UINT32_MAX;
    for (uint32_t i = 0; i < ctx->physicalDeviceMemoryProperty.memoryTypeCount; ++i)
    {
        const VkMemoryPropertyFlags available = ctx->physicalDeviceMemoryProperty.memoryTypes[i].propertyFlags;
        if (((memRequirements.memoryTypeBits >> i) & 1u) != 0 && (available & flag) == flag)
        {
            memoryTypeIndex = i;
            break;
        }
    }

    VkResult result = VK_ERROR_FEATURE_NOT_PRESENT; /* no suitable memory type */
    VkDeviceMemory memory = VK_NULL_HANDLE;
    if (memoryTypeIndex != UINT32_MAX)
    {
        VkMemoryAllocateInfo memoryAllocInfo = { VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO };
        memoryAllocInfo.allocationSize = memRequirements.size;
        memoryAllocInfo.memoryTypeIndex = memoryTypeIndex;
        result = vkAllocateMemory(ctx->device, &memoryAllocInfo, NULL, &memory);
    }
    if (result == VK_SUCCESS)
    {
        result = vkBindBufferMemory(ctx->device, buffer, memory, 0);
        if (result != VK_SUCCESS)
        {
            vkFreeMemory(ctx->device, memory, NULL);
        }
    }
    if (result != VK_SUCCESS)
    {
        /* Do not leak the buffer object when allocation or binding fails. */
        vkDestroyBuffer(ctx->device, buffer, NULL);
        return result;
    }

    pBuffer->Buffer = buffer;
    pBuffer->Memory = memory;
    return VK_SUCCESS;
}

/* Release a storage buffer and its backing allocation.
 * Each handle is reset to VK_NULL_HANDLE once released, so calling this on a
 * buffer that was never created, or calling it twice, does nothing further. */
static void DestroyStorageBuffer(DeviceContext* ctx, StorageBuffer* pBuffer)
{
    const VkDevice device = ctx->device;

    if (pBuffer->Buffer != VK_NULL_HANDLE)
    {
        vkDestroyBuffer(device, pBuffer->Buffer, NULL);
        pBuffer->Buffer = VK_NULL_HANDLE;
    }

    if (pBuffer->Memory == VK_NULL_HANDLE)
        return;

    vkFreeMemory(device, pBuffer->Memory, NULL);
    pBuffer->Memory = VK_NULL_HANDLE;
}

/* Build a compute pipeline (entry point "main") from a SPIR-V blob.
 *
 * data   : SPIR-V words
 * length : size of `data` in BYTES (passed as VkShaderModuleCreateInfo::codeSize)
 *
 * The pipeline uses ctx->pipelineLayout, which must already exist. The
 * temporary shader module is destroyed whether or not pipeline creation
 * succeeds. Returns the first failing VkResult, or VK_SUCCESS. */
static VkResult CreatePipelineFromSpirv(DeviceContext* ctx, VkPipeline* pPipeline, const uint32_t* data, const uint32_t length)
{
    VkShaderModule module = VK_NULL_HANDLE;
    VkShaderModuleCreateInfo shaderModuleInfo = { VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO };
    shaderModuleInfo.pCode = data;
    shaderModuleInfo.codeSize = length;

    VkResultCheck(vkCreateShaderModule(ctx->device, &shaderModuleInfo, NULL, &module));

    VkComputePipelineCreateInfo computePipelineInfo = { VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO };
    computePipelineInfo.basePipelineHandle = VK_NULL_HANDLE;
    computePipelineInfo.layout = ctx->pipelineLayout;
    computePipelineInfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
    computePipelineInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
    computePipelineInfo.stage.module = module;
    computePipelineInfo.stage.pName = "main";
    const VkResult result = vkCreateComputePipelines(ctx->device, VK_NULL_HANDLE, 1, &computePipelineInfo, NULL, pPipeline);

    /* The module is only needed while the pipeline is being created; release it
     * on the failure path too instead of leaking it via an early return. */
    vkDestroyShaderModule(ctx->device, module, NULL);
    return result;
}

/* Record one forward-layer dispatch into ctx->commandBuffer.
 *
 * A barrier first makes earlier shader writes to the weights buffer and the
 * feature buffer readable by this dispatch. With a TrainCache, the cache's
 * batched feature buffer is used and the Y dispatch dimension is the batch size.
 * Without one (cache == NULL, as Predict does), the context's single feature
 * buffer is used with a Y dimension of 1. `threads` is rounded up to whole
 * THREADGROUP_SIZE workgroups in X. */
static void Forward(DeviceContext* ctx, TrainCache* cache, VkPipeline pipeline, uint32_t threads)
{
    const int batched = (cache != NULL);
    const VkBuffer guarded[] =
    {
        ctx->lenet.Buffer,
        batched ? cache->feature.Buffer : ctx->feature.Buffer,
    };

    VkBufferMemoryBarrier barriers[GETLENGTH(guarded)];
    for (size_t n = 0; n < GETLENGTH(guarded); ++n)
    {
        VkBufferMemoryBarrier b = { VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER };
        b.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT;
        b.dstAccessMask = VK_ACCESS_SHADER_READ_BIT;
        b.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
        b.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
        b.buffer = guarded[n];
        b.size = VK_WHOLE_SIZE;
        barriers[n] = b;
    }

    const uint32_t groupsX = (threads + THREADGROUP_SIZE - 1) / THREADGROUP_SIZE;
    const uint32_t groupsY = batched ? cache->batchSize : 1;

    VkCommandBuffer cmd = ctx->commandBuffer;
    vkCmdBindDescriptorSets(cmd, VK_PIPELINE_BIND_POINT_COMPUTE, ctx->pipelineLayout, 0, 1, &ctx->descriptorSet, 0, NULL);
    vkCmdBindPipeline(cmd, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
    vkCmdPipelineBarrier(cmd,
        VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
        0, NULL,
        (uint32_t)GETLENGTH(barriers), barriers,
        0, NULL);
    vkCmdDispatch(cmd, groupsX, groupsY, 1);
}

/* Record the softmax dispatch into ctx->commandBuffer.
 *
 * A barrier first makes earlier shader writes to the weights, feature, error
 * and label buffers readable. The dispatch is a single workgroup in X and
 * cache->batchSize workgroups in Y. */
static void Softmax(DeviceContext* ctx, TrainCache* cache)
{
    const VkBuffer guarded[] =
    {
        ctx->lenet.Buffer,
        cache->feature.Buffer,
        cache->error.Buffer,
        cache->label.Buffer,
    };

    VkBufferMemoryBarrier barriers[GETLENGTH(guarded)];
    for (size_t n = 0; n < GETLENGTH(guarded); ++n)
    {
        VkBufferMemoryBarrier b = { VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER };
        b.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT;
        b.dstAccessMask = VK_ACCESS_SHADER_READ_BIT;
        b.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
        b.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
        b.buffer = guarded[n];
        b.size = VK_WHOLE_SIZE;
        barriers[n] = b;
    }

    VkCommandBuffer cmd = ctx->commandBuffer;
    vkCmdBindDescriptorSets(cmd, VK_PIPELINE_BIND_POINT_COMPUTE, ctx->pipelineLayout, 0, 1, &ctx->descriptorSet, 0, NULL);
    vkCmdBindPipeline(cmd, VK_PIPELINE_BIND_POINT_COMPUTE, ctx->softmaxPipeline);
    vkCmdPipelineBarrier(cmd,
        VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
        0, NULL,
        (uint32_t)GETLENGTH(barriers), barriers,
        0, NULL);
    vkCmdDispatch(cmd, 1, cache->batchSize, 1);
}

/* Record one backward-layer dispatch into ctx->commandBuffer.
 *
 * A barrier first makes earlier shader writes to the weights, feature and error
 * buffers readable. `threads` is rounded up to whole THREADGROUP_SIZE
 * workgroups in X, and the Y dimension is cache->batchSize. */
static void Backward(DeviceContext* ctx, TrainCache* cache, VkPipeline pipeline, uint32_t threads)
{
    const VkBuffer guarded[] =
    {
        ctx->lenet.Buffer,
        cache->feature.Buffer,
        cache->error.Buffer,
    };

    VkBufferMemoryBarrier barriers[GETLENGTH(guarded)];
    for (size_t n = 0; n < GETLENGTH(guarded); ++n)
    {
        VkBufferMemoryBarrier b = { VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER };
        b.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT;
        b.dstAccessMask = VK_ACCESS_SHADER_READ_BIT;
        b.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
        b.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
        b.buffer = guarded[n];
        b.size = VK_WHOLE_SIZE;
        barriers[n] = b;
    }

    const uint32_t groupsX = (threads + THREADGROUP_SIZE - 1) / THREADGROUP_SIZE;

    VkCommandBuffer cmd = ctx->commandBuffer;
    vkCmdBindDescriptorSets(cmd, VK_PIPELINE_BIND_POINT_COMPUTE, ctx->pipelineLayout, 0, 1, &ctx->descriptorSet, 0, NULL);
    vkCmdBindPipeline(cmd, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
    vkCmdPipelineBarrier(cmd,
        VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
        0, NULL,
        (uint32_t)GETLENGTH(barriers), barriers,
        0, NULL);
    vkCmdDispatch(cmd, groupsX, cache->batchSize, 1);
}

/* Record one delta (parameter-gradient) dispatch into ctx->commandBuffer.
 *
 * A barrier first makes earlier shader writes to the weights, feature, error
 * and delta buffers readable. `threads` is rounded up to whole THREADGROUP_SIZE
 * workgroups in X, and the Y dimension is cache->batchSize. */
static void Delta(DeviceContext* ctx, TrainCache* cache, VkPipeline pipeline, uint32_t threads)
{
    const VkBuffer guarded[] =
    {
        ctx->lenet.Buffer,
        cache->feature.Buffer,
        cache->error.Buffer,
        cache->delta.Buffer,
    };

    VkBufferMemoryBarrier barriers[GETLENGTH(guarded)];
    for (size_t n = 0; n < GETLENGTH(guarded); ++n)
    {
        VkBufferMemoryBarrier b = { VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER };
        b.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT;
        b.dstAccessMask = VK_ACCESS_SHADER_READ_BIT;
        b.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
        b.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
        b.buffer = guarded[n];
        b.size = VK_WHOLE_SIZE;
        barriers[n] = b;
    }

    const uint32_t groupsX = (threads + THREADGROUP_SIZE - 1) / THREADGROUP_SIZE;

    VkCommandBuffer cmd = ctx->commandBuffer;
    vkCmdBindDescriptorSets(cmd, VK_PIPELINE_BIND_POINT_COMPUTE, ctx->pipelineLayout, 0, 1, &ctx->descriptorSet, 0, NULL);
    vkCmdBindPipeline(cmd, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
    vkCmdPipelineBarrier(cmd,
        VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
        0, NULL,
        (uint32_t)GETLENGTH(barriers), barriers,
        0, NULL);
    vkCmdDispatch(cmd, groupsX, cache->batchSize, 1);
}

/* Record the weight-update dispatch into ctx->commandBuffer.
 *
 * A barrier first makes earlier shader writes to the weights and delta buffers
 * readable. Two uint32 push constants are then supplied: cache->batchSize and
 * sizeof(LeNet5)/sizeof(float). The dispatch covers that float count in
 * THREADGROUP_SIZE-wide workgroups along X. */
static void Update(DeviceContext* ctx, TrainCache* cache)
{
    const VkBuffer guarded[] = { ctx->lenet.Buffer, cache->delta.Buffer };

    VkBufferMemoryBarrier barriers[GETLENGTH(guarded)];
    for (size_t n = 0; n < GETLENGTH(guarded); ++n)
    {
        VkBufferMemoryBarrier b = { VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER };
        b.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT;
        b.dstAccessMask = VK_ACCESS_SHADER_READ_BIT;
        b.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
        b.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
        b.buffer = guarded[n];
        b.size = VK_WHOLE_SIZE;
        barriers[n] = b;
    }

    const size_t parameterCount = GETCOUNT(LeNet5);
    const uint32_t pushData[2] = { cache->batchSize, (uint32_t)parameterCount };
    const uint32_t groupsX = (uint32_t)((parameterCount + THREADGROUP_SIZE - 1) / THREADGROUP_SIZE);

    VkCommandBuffer cmd = ctx->commandBuffer;
    vkCmdBindDescriptorSets(cmd, VK_PIPELINE_BIND_POINT_COMPUTE, ctx->pipelineLayout, 0, 1, &ctx->descriptorSet, 0, NULL);
    vkCmdBindPipeline(cmd, VK_PIPELINE_BIND_POINT_COMPUTE, ctx->updatePipeline);
    vkCmdPipelineBarrier(cmd,
        VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
        0, NULL,
        (uint32_t)GETLENGTH(barriers), barriers,
        0, NULL);
    vkCmdPushConstants(cmd, ctx->pipelineLayout, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(pushData), pushData);
    vkCmdDispatch(cmd, groupsX, 1, 1);
}


/* Initialise every Vulkan object the LeNet-5 runner needs.
 *
 * This creates, in order:
 *   - the instance (plus validation in _DEBUG builds);
 *   - the highest-ranked physical device and a queue from a compute-capable family;
 *   - the command pool and command buffer;
 *   - a 5-binding storage-buffer descriptor set layout, pool and set;
 *   - a pipeline layout with an 8-byte compute push-constant range;
 *   - all compute pipelines;
 *   - the host-visible/coherent weights (LeNet5) and single-sample Feature buffers.
 *
 * *ctx is zeroed first. On failure, the first failing VkResult is returned and
 * any objects already created are left in *ctx (they are not released here). */
VkResult CreateDeviceContext(DeviceContext* ctx)
{
    memset(ctx, 0, sizeof(DeviceContext));

    VkApplicationInfo appInfo = { VK_STRUCTURE_TYPE_APPLICATION_INFO };
    appInfo.applicationVersion = VK_MAKE_VERSION(1, 0, 0);
    appInfo.engineVersion = VK_MAKE_VERSION(1, 0, 0);
    appInfo.apiVersion = VK_API_VERSION_1_0;
    VkInstanceCreateInfo createInfo = { VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO };
    createInfo.pApplicationInfo = &appInfo;

#ifdef _DEBUG
    /* VK_LAYER_LUNARG_standard_validation is deprecated and no longer shipped by
     * current Vulkan SDKs (instance creation then fails with
     * VK_ERROR_LAYER_NOT_PRESENT); VK_LAYER_KHRONOS_validation replaces it. */
    const char* validationExt[] = { VK_EXT_DEBUG_REPORT_EXTENSION_NAME };
    const char* validationLayers[] = { "VK_LAYER_KHRONOS_validation" };
    createInfo.enabledLayerCount = GETLENGTH(validationLayers);
    createInfo.enabledExtensionCount = GETLENGTH(validationExt);
    createInfo.ppEnabledLayerNames = validationLayers;
    createInfo.ppEnabledExtensionNames = validationExt;
#endif // _DEBUG

    VkResultCheck(vkCreateInstance(&createInfo, NULL, &ctx->instance));

    /* Inspect at most 16 adapters. VK_INCOMPLETE only means more exist than fit
     * in the array; the ones returned are valid, so it is not an error here. */
    VkPhysicalDevice physicalDevice[16] = { 0 };
    VkPhysicalDeviceProperties deviceProperty[16] = { 0 };
    uint32_t deviceCount = GETLENGTH(physicalDevice);
    const VkResult enumResult = vkEnumeratePhysicalDevices(ctx->instance, &deviceCount, physicalDevice);
    if (enumResult != VK_SUCCESS && enumResult != VK_INCOMPLETE)
        return enumResult;
    if (deviceCount == 0)
        return VK_ERROR_INITIALIZATION_FAILED;

    /* Preference rank indexed by VkPhysicalDeviceType:
     * OTHER=1, INTEGRATED_GPU=2, DISCRETE_GPU=4, VIRTUAL_GPU=3, CPU=0.
     * The highest rank wins; ties keep the earliest-enumerated device. */
    const uint8_t physicalDevicePriorities[] = { 1, 2, 4, 3, 0 };
    uint32_t usePhysicalDevice = 0;
    for (uint32_t i = 0; i < deviceCount; ++i)
    {
        vkGetPhysicalDeviceProperties(physicalDevice[i], &deviceProperty[i]);
        if (physicalDevicePriorities[deviceProperty[i].deviceType] >
            physicalDevicePriorities[deviceProperty[usePhysicalDevice].deviceType])
        {
            usePhysicalDevice = i;
        }
    }
    printf("Use Physical Device: %s \n", deviceProperty[usePhysicalDevice].deviceName);
    const VkPhysicalDevice gpu = physicalDevice[usePhysicalDevice];

    /* Family 0 is not guaranteed to support compute; pick the first family that does. */
    VkQueueFamilyProperties queueFamilies[16];
    uint32_t queueFamilyCount = GETLENGTH(queueFamilies);
    vkGetPhysicalDeviceQueueFamilyProperties(gpu, &queueFamilyCount, queueFamilies);
    uint32_t computeQueueFamily = VK_QUEUE_FAMILY_IGNORED;
    for (uint32_t i = 0; i < queueFamilyCount; ++i)
    {
        if (queueFamilies[i].queueCount > 0 && (queueFamilies[i].queueFlags & VK_QUEUE_COMPUTE_BIT) != 0)
        {
            computeQueueFamily = i;
            break;
        }
    }
    if (computeQueueFamily == VK_QUEUE_FAMILY_IGNORED)
        return VK_ERROR_FEATURE_NOT_PRESENT;

    float priorities = 1.0f;
    VkDeviceCreateInfo deviceInfo = { VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO };
    VkDeviceQueueCreateInfo queueInfo = { VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO };
    queueInfo.queueFamilyIndex = computeQueueFamily;
    queueInfo.queueCount = 1;
    queueInfo.pQueuePriorities = &priorities;
    deviceInfo.queueCreateInfoCount = 1;
    deviceInfo.pQueueCreateInfos = &queueInfo;
    VkResultCheck(vkCreateDevice(gpu, &deviceInfo, NULL, &ctx->device));
    vkGetDeviceQueue(ctx->device, computeQueueFamily, 0, &ctx->queue);

    /* RESET_COMMAND_BUFFER lets the single command buffer be re-recorded per call. */
    VkCommandPoolCreateInfo commandPoolInfo = { VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO };
    commandPoolInfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;
    commandPoolInfo.queueFamilyIndex = computeQueueFamily;
    VkResultCheck(vkCreateCommandPool(ctx->device, &commandPoolInfo, NULL, &ctx->commandPool));

    VkCommandBufferAllocateInfo commandBufferInfo = { VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO };
    commandBufferInfo.commandPool = ctx->commandPool;
    commandBufferInfo.commandBufferCount = 1;
    VkResultCheck(vkAllocateCommandBuffers(ctx->device, &commandBufferInfo, &ctx->commandBuffer));

    /* Bindings 0..4: one storage buffer each, visible to compute shaders. */
    VkDescriptorSetLayoutBinding descriptorSetLayoutBinding[5] = { { 0 } };
    VkDescriptorPoolSize poolSizes[GETLENGTH(descriptorSetLayoutBinding)];
    for (uint32_t i = 0; i < GETLENGTH(descriptorSetLayoutBinding); ++i)
    {
        descriptorSetLayoutBinding[i].binding = i;
        descriptorSetLayoutBinding[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
        descriptorSetLayoutBinding[i].descriptorCount = 1;
        descriptorSetLayoutBinding[i].stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
        poolSizes[i].type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
        poolSizes[i].descriptorCount = 1;
    }
    VkDescriptorSetLayoutCreateInfo descriptorSetLayoutInfo = { VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO };
    descriptorSetLayoutInfo.bindingCount = GETLENGTH(descriptorSetLayoutBinding);
    descriptorSetLayoutInfo.pBindings = descriptorSetLayoutBinding;
    VkResultCheck(vkCreateDescriptorSetLayout(ctx->device, &descriptorSetLayoutInfo, NULL, &ctx->descriptorSetLayout));

    /* 8 bytes of push constants (two uint32 values) for compute shaders. */
    VkPushConstantRange constantRange = { VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(uint32_t) * 2 };
    VkPipelineLayoutCreateInfo pipelineLayoutInfo = { VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO };
    pipelineLayoutInfo.setLayoutCount = 1;
    pipelineLayoutInfo.pSetLayouts = &ctx->descriptorSetLayout;
    pipelineLayoutInfo.pushConstantRangeCount = 1;
    pipelineLayoutInfo.pPushConstantRanges = &constantRange;
    VkResultCheck(vkCreatePipelineLayout(ctx->device, &pipelineLayoutInfo, NULL, &ctx->pipelineLayout));

    VkDescriptorPoolCreateInfo descriptorPoolInfo = { VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO };
    descriptorPoolInfo.poolSizeCount = GETLENGTH(descriptorSetLayoutBinding);
    descriptorPoolInfo.pPoolSizes = poolSizes;
    descriptorPoolInfo.maxSets = 1;
    /* Allows vkFreeDescriptorSets on the set in DestroyDeviceContext. */
    descriptorPoolInfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
    VkResultCheck(vkCreateDescriptorPool(ctx->device, &descriptorPoolInfo, NULL, &ctx->descriptorPool));

    VkDescriptorSetAllocateInfo descriptorSetAllocInfo = { VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO };
    descriptorSetAllocInfo.descriptorPool = ctx->descriptorPool;
    descriptorSetAllocInfo.descriptorSetCount = 1;
    descriptorSetAllocInfo.pSetLayouts = &ctx->descriptorSetLayout;
    VkResultCheck(vkAllocateDescriptorSets(ctx->device, &descriptorSetAllocInfo, &ctx->descriptorSet));

    /* SPIR-V blob plus its size in bytes. */
    typedef struct
    {
        const uint32_t* data;
        const uint32_t length;
    }ShaderData;
    ShaderData lenet_forward[] =
    {
        { lenet_forward0, sizeof(lenet_forward0) },
        { lenet_forward1, sizeof(lenet_forward1) },
        { lenet_forward2, sizeof(lenet_forward2) },
        { lenet_forward3, sizeof(lenet_forward3) },
        { lenet_forward4, sizeof(lenet_forward4) },
        { lenet_forward5, sizeof(lenet_forward5) },
    };

    ShaderData lenet_backward[] =
    {
        { lenet_backward0, sizeof(lenet_backward0) },
        { lenet_backward1, sizeof(lenet_backward1) },
        { lenet_backward2, sizeof(lenet_backward2) },
        { lenet_backward3, sizeof(lenet_backward3) },
        { lenet_backward4, sizeof(lenet_backward4) },
        { lenet_backward5, sizeof(lenet_backward5) },
    };

    ShaderData lenet_delta[] =
    {
        { lenet_delta0, sizeof(lenet_delta0) },
        { lenet_delta2, sizeof(lenet_delta2) },
        { lenet_delta4, sizeof(lenet_delta4) },
        { lenet_delta5, sizeof(lenet_delta5) },
    };

    for (uint32_t i = 0; i < GETLENGTH(lenet_forward); ++i)
    {
        VkResultCheck(CreatePipelineFromSpirv(ctx, &ctx->forwardPipeline[i], lenet_forward[i].data, lenet_forward[i].length));
    }
    VkResultCheck(CreatePipelineFromSpirv(ctx, &ctx->softmaxPipeline, lenet_softmax, sizeof(lenet_softmax)));
    for (uint32_t i = 0; i < GETLENGTH(lenet_backward); ++i)
    {
        VkResultCheck(CreatePipelineFromSpirv(ctx, &ctx->backwardPipeline[i], lenet_backward[i].data, lenet_backward[i].length));
    }
    for (uint32_t i = 0; i < GETLENGTH(lenet_delta); ++i)
    {
        VkResultCheck(CreatePipelineFromSpirv(ctx, &ctx->deltaPipeline[i], lenet_delta[i].data, lenet_delta[i].length));
    }
    VkResultCheck(CreatePipelineFromSpirv(ctx, &ctx->updatePipeline, lenet_update, sizeof(lenet_update)));

    /* Must be queried before CreateStorageBuffer, which selects memory types from it. */
    vkGetPhysicalDeviceMemoryProperties(gpu, &ctx->physicalDeviceMemoryProperty);

    VkResultCheck(CreateStorageBuffer(ctx, &ctx->lenet, sizeof(LeNet5), VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT));
    VkResultCheck(CreateStorageBuffer(ctx, &ctx->feature, sizeof(Feature), VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT));

    return VK_SUCCESS;
}




/* Release everything created by CreateDeviceContext and reset *ctx to zero.
 *
 * This is safe after a partially failed CreateDeviceContext: that function zeroes
 * *ctx first, and here every handle is checked or passed to a vkDestroy* call
 * that accepts VK_NULL_HANDLE. It is also safe to call twice, since *ctx is
 * zeroed at the end. */
void DestroyDeviceContext(DeviceContext* ctx)
{
    if (ctx->device != VK_NULL_HANDLE)
    {
        /* Make sure no submitted work still references the objects below. */
        vkDeviceWaitIdle(ctx->device);

        /* Both free calls require a valid pool handle, so skip them if the pool was never created. */
        if (ctx->commandPool != VK_NULL_HANDLE && ctx->commandBuffer != VK_NULL_HANDLE)
            vkFreeCommandBuffers(ctx->device, ctx->commandPool, 1, &ctx->commandBuffer);
        if (ctx->descriptorPool != VK_NULL_HANDLE && ctx->descriptorSet != VK_NULL_HANDLE)
            vkFreeDescriptorSets(ctx->device, ctx->descriptorPool, 1, &ctx->descriptorSet);

        DestroyStorageBuffer(ctx, &ctx->lenet);
        DestroyStorageBuffer(ctx, &ctx->feature);
        vkDestroyDescriptorSetLayout(ctx->device, ctx->descriptorSetLayout, NULL);
        vkDestroyDescriptorPool(ctx->device, ctx->descriptorPool, NULL);
        vkDestroyPipelineLayout(ctx->device, ctx->pipelineLayout, NULL);
        for (size_t i = 0; i < GETLENGTH(ctx->forwardPipeline); ++i)
        {
            vkDestroyPipeline(ctx->device, ctx->forwardPipeline[i], NULL);
        }
        vkDestroyPipeline(ctx->device, ctx->softmaxPipeline, NULL);
        for (size_t i = 0; i < GETLENGTH(ctx->backwardPipeline); ++i)
        {
            vkDestroyPipeline(ctx->device, ctx->backwardPipeline[i], NULL);
        }
        for (size_t i = 0; i < GETLENGTH(ctx->deltaPipeline); ++i)
        {
            vkDestroyPipeline(ctx->device, ctx->deltaPipeline[i], NULL);
        }
        vkDestroyPipeline(ctx->device, ctx->updatePipeline, NULL);
        vkDestroyCommandPool(ctx->device, ctx->commandPool, NULL);
        vkDestroyDevice(ctx->device, NULL);
    }
    if (ctx->instance != VK_NULL_HANDLE)
        vkDestroyInstance(ctx->instance, NULL);

    memset(ctx, 0, sizeof(DeviceContext));
}

/* Allocate the per-batch GPU buffers used by TrainBatch.
 *
 *   feature : batchSize * Feature,  host-visible + coherent (uploaded by the host)
 *   error   : batchSize * Feature,  device-local
 *   label   : batchSize * uint32_t, host-visible + coherent (uploaded by the host)
 *   delta   : batchSize * LeNet5,   device-local
 *
 * Returns VK_ERROR_INITIALIZATION_FAILED for batchSize == 0, because Vulkan
 * forbids zero-sized buffers. On any failure, buffers created so far are
 * released and all four handles are left as VK_NULL_HANDLE. */
VkResult CreateTrainCache(DeviceContext* ctx, TrainCache* cache, const uint32_t batchSize)
{
    StorageBuffer* const buffers[] = { &cache->feature, &cache->error, &cache->label, &cache->delta };
    for (size_t i = 0; i < GETLENGTH(buffers); ++i)
    {
        buffers[i]->Buffer = VK_NULL_HANDLE;
        buffers[i]->Memory = VK_NULL_HANDLE;
    }

    cache->batchSize = batchSize;
    if (batchSize == 0)
        return VK_ERROR_INITIALIZATION_FAILED;

    const VkMemoryPropertyFlags hostFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
    /* Multiply in 64-bit VkDeviceSize so the sizes cannot wrap when size_t is 32-bit. */
    const VkDeviceSize count = batchSize;

    VkResult result = CreateStorageBuffer(ctx, &cache->feature, sizeof(Feature) * count, hostFlags);
    if (result == VK_SUCCESS)
        result = CreateStorageBuffer(ctx, &cache->error, sizeof(Feature) * count, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT);
    if (result == VK_SUCCESS)
        result = CreateStorageBuffer(ctx, &cache->label, sizeof(uint32_t) * count, hostFlags);
    if (result == VK_SUCCESS)
        result = CreateStorageBuffer(ctx, &cache->delta, sizeof(LeNet5) * count, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT);

    if (result != VK_SUCCESS)
    {
        /* Don't leak the buffers that were created before the failure. */
        for (size_t i = 0; i < GETLENGTH(buffers); ++i)
            DestroyStorageBuffer(ctx, buffers[i]);
    }
    return result;
}


/* Return the index of the largest of output[0..count).
 * Ties resolve to the earliest index. If count is 0 or 1 the result is 0, and
 * output is not read when count is 0. */
static uint32_t GetMaxResult(float* output, uint32_t count)
{
    uint32_t best = 0;
    for (uint32_t idx = 1; idx < count; idx++)
    {
        if (output[idx] > output[best])
            best = idx;
    }
    return best;
}


/* Release the four per-batch GPU buffers (feature, error, label, delta), in the
 * order CreateTrainCache allocates them. Buffers already released or never
 * created are skipped by DestroyStorageBuffer. */
void DestroyTrainCache(DeviceContext* ctx, TrainCache* cache)
{
    StorageBuffer* const owned[] = { &cache->feature, &cache->error, &cache->label, &cache->delta };
    for (size_t k = 0; k < GETLENGTH(owned); ++k)
        DestroyStorageBuffer(ctx, owned[k]);
}

/* Copy the host-side weights in *data into the device weights buffer.
 * The buffer is allocated host-visible and host-coherent, so a plain
 * map/memcpy/unmap is enough. If mapping fails, nothing is copied and the
 * device weights are left unchanged (instead of writing through a NULL
 * pointer). */
void LoadModel(DeviceContext* lenet, LeNet5* data)
{
    void* addr = NULL;
    if (vkMapMemory(lenet->device, lenet->lenet.Memory, 0, sizeof(LeNet5), 0, &addr) != VK_SUCCESS || addr == NULL)
        return;
    memcpy(addr, data, sizeof(LeNet5));
    vkUnmapMemory(lenet->device, lenet->lenet.Memory);
}

/* Copy the device weights buffer into the host-side *data.
 * The buffer is allocated host-visible and host-coherent, so no explicit
 * invalidate is issued. GPU work that writes the weights is assumed to have
 * completed before this is called. If mapping fails, *data is left untouched
 * (instead of reading through a NULL pointer). */
void SaveModel(DeviceContext* lenet, LeNet5* data)
{
    void* addr = NULL;
    if (vkMapMemory(lenet->device, lenet->lenet.Memory, 0, sizeof(LeNet5), 0, &addr) != VK_SUCCESS || addr == NULL)
        return;
    memcpy(data, addr, sizeof(LeNet5));
    vkUnmapMemory(lenet->device, lenet->lenet.Memory);
}

/* Run a single-sample forward pass and return the argmax of layer6.
 *
 * `feature` supplies the input (the whole Feature struct is uploaded). Bindings
 * 0 and 1 are pointed at the weights and ctx->feature. The six forward
 * pipelines are recorded and submitted, and the call blocks until the queue is
 * idle.
 *
 * Returns UINT32_MAX if mapping, recording or submission fails, so a failure
 * can be told apart from a real class index. */
uint32_t Predict(DeviceContext* ctx, Feature* feature)
{
    void* mapped = NULL;
    if (vkMapMemory(ctx->device, ctx->feature.Memory, 0, sizeof(Feature), 0, &mapped) != VK_SUCCESS || mapped == NULL)
        return UINT32_MAX;
    memcpy(mapped, feature, sizeof(*feature));
    vkUnmapMemory(ctx->device, ctx->feature.Memory);

    const VkDescriptorBufferInfo descriptorBufferInfo[2] =
    {
        { ctx->lenet.Buffer, 0, VK_WHOLE_SIZE },
        { ctx->feature.Buffer, 0, VK_WHOLE_SIZE },
    };
    VkWriteDescriptorSet writeDescriptorSet = { VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET };
    writeDescriptorSet.dstSet = ctx->descriptorSet;
    writeDescriptorSet.dstBinding = 0;
    writeDescriptorSet.descriptorCount = GETLENGTH(descriptorBufferInfo);
    writeDescriptorSet.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
    writeDescriptorSet.pBufferInfo = descriptorBufferInfo;
    vkUpdateDescriptorSets(ctx->device, 1, &writeDescriptorSet, 0, NULL);

    VkCommandBufferBeginInfo commandBufferBeginInfo = { VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO };
    if (vkBeginCommandBuffer(ctx->commandBuffer, &commandBufferBeginInfo) != VK_SUCCESS)
        return UINT32_MAX;
    Forward(ctx, NULL, ctx->forwardPipeline[0], GETCOUNT(feature->layer1));
    Forward(ctx, NULL, ctx->forwardPipeline[1], GETCOUNT(feature->layer2));
    Forward(ctx, NULL, ctx->forwardPipeline[2], GETCOUNT(feature->layer3));
    Forward(ctx, NULL, ctx->forwardPipeline[3], GETCOUNT(feature->layer4));
    Forward(ctx, NULL, ctx->forwardPipeline[4], GETCOUNT(feature->layer5));
    Forward(ctx, NULL, ctx->forwardPipeline[5], GETCOUNT(feature->layer6));

    /* vkQueueWaitIdle alone does not make shader writes visible to the host.
     * This barrier makes them available for the host read of layer6 below. */
    VkBufferMemoryBarrier toHost = { VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER };
    toHost.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT;
    toHost.dstAccessMask = VK_ACCESS_HOST_READ_BIT;
    toHost.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
    toHost.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
    toHost.buffer = ctx->feature.Buffer;
    toHost.size = VK_WHOLE_SIZE;
    vkCmdPipelineBarrier(ctx->commandBuffer,
        VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_HOST_BIT, 0,
        0, NULL, 1, &toHost, 0, NULL);

    if (vkEndCommandBuffer(ctx->commandBuffer) != VK_SUCCESS)
        return UINT32_MAX;
    VkSubmitInfo submitInfo = { VK_STRUCTURE_TYPE_SUBMIT_INFO };
    submitInfo.commandBufferCount = 1;
    submitInfo.pCommandBuffers = &ctx->commandBuffer;
    if (vkQueueSubmit(ctx->queue, 1, &submitInfo, VK_NULL_HANDLE) != VK_SUCCESS)
        return UINT32_MAX;
    if (vkQueueWaitIdle(ctx->queue) != VK_SUCCESS)
        return UINT32_MAX;

    mapped = NULL;
    if (vkMapMemory(ctx->device, ctx->feature.Memory, 0, VK_WHOLE_SIZE, 0, &mapped) != VK_SUCCESS || mapped == NULL)
        return UINT32_MAX;
    Feature* devfeature = (Feature*)mapped;
    /* Keep the full uint32_t index; it was previously truncated through uint8_t. */
    const uint32_t res = GetMaxResult(devfeature->layer6, GETCOUNT(devfeature->layer6));
    vkUnmapMemory(ctx->device, ctx->feature.Memory);
    return res;
}

/* Run one training step over cache->batchSize samples and wait for it to finish.
 *
 * `feature` and `label` must each hold cache->batchSize entries. They are
 * uploaded to the cache's host-visible buffers. Bindings 0..4 are set to
 * weights, feature, error, delta and label. The recorded sequence is:
 * forward (6 layers), softmax, backward (6 layers, last to first), delta
 * (4 passes), update.
 *
 * If mapping, recording or submission fails, the step is skipped (nothing is
 * reported to the caller; the function returns void). */
void TrainBatch(DeviceContext* ctx, TrainCache* cache, Feature* feature, uint32_t* label)
{
    const VkDescriptorBufferInfo descriptorBufferInfo[] =
    {
        { ctx->lenet.Buffer, 0, VK_WHOLE_SIZE },
        { cache->feature.Buffer, 0, VK_WHOLE_SIZE },
        { cache->error.Buffer, 0, VK_WHOLE_SIZE },
        { cache->delta.Buffer, 0, VK_WHOLE_SIZE },
        { cache->label.Buffer, 0, VK_WHOLE_SIZE }
    };
    VkWriteDescriptorSet writeDescriptorSet = { VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET };
    writeDescriptorSet.dstSet = ctx->descriptorSet;
    writeDescriptorSet.dstBinding = 0;
    writeDescriptorSet.descriptorCount = GETLENGTH(descriptorBufferInfo);
    writeDescriptorSet.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
    writeDescriptorSet.pBufferInfo = descriptorBufferInfo;
    vkUpdateDescriptorSets(ctx->device, 1, &writeDescriptorSet, 0, NULL);

    void* mappedFeature = NULL;
    void* mappedLabel = NULL;
    if (vkMapMemory(ctx->device, cache->feature.Memory, 0, VK_WHOLE_SIZE, 0, &mappedFeature) != VK_SUCCESS || mappedFeature == NULL)
        return;
    if (vkMapMemory(ctx->device, cache->label.Memory, 0, VK_WHOLE_SIZE, 0, &mappedLabel) != VK_SUCCESS || mappedLabel == NULL)
    {
        vkUnmapMemory(ctx->device, cache->feature.Memory);
        return;
    }
    memcpy(mappedFeature, feature, sizeof(*feature) * cache->batchSize);
    memcpy(mappedLabel, label, sizeof(*label) * cache->batchSize);
    vkUnmapMemory(ctx->device, cache->label.Memory);
    vkUnmapMemory(ctx->device, cache->feature.Memory);

    VkCommandBufferBeginInfo commandBufferBeginInfo = { VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO };
    if (vkBeginCommandBuffer(ctx->commandBuffer, &commandBufferBeginInfo) != VK_SUCCESS)
        return;
    Forward(ctx, cache, ctx->forwardPipeline[0], GETCOUNT(feature->layer1));
    Forward(ctx, cache, ctx->forwardPipeline[1], GETCOUNT(feature->layer2));
    Forward(ctx, cache, ctx->forwardPipeline[2], GETCOUNT(feature->layer3));
    Forward(ctx, cache, ctx->forwardPipeline[3], GETCOUNT(feature->layer4));
    Forward(ctx, cache, ctx->forwardPipeline[4], GETCOUNT(feature->layer5));
    Forward(ctx, cache, ctx->forwardPipeline[5], GETCOUNT(feature->layer6));
    Softmax(ctx, cache);
    /* NOTE(review): backward3 and backward1 are sized by layer4 and layer2 (not
     * layer3 and layer1). Presumably these are pooling backward passes that
     * iterate over the pooled outputs; confirm against the shaders. */
    Backward(ctx, cache, ctx->backwardPipeline[5], GETCOUNT(feature->layer5));
    Backward(ctx, cache, ctx->backwardPipeline[4], GETCOUNT(feature->layer4));
    Backward(ctx, cache, ctx->backwardPipeline[3], GETCOUNT(feature->layer4));
    Backward(ctx, cache, ctx->backwardPipeline[2], GETCOUNT(feature->layer2));
    Backward(ctx, cache, ctx->backwardPipeline[1], GETCOUNT(feature->layer2));
    Backward(ctx, cache, ctx->backwardPipeline[0], GETCOUNT(feature->layer0));
    Delta(ctx, cache, ctx->deltaPipeline[0], GETCOUNT(((LeNet5*)0)->weight0_1));
    Delta(ctx, cache, ctx->deltaPipeline[1], GETCOUNT(((LeNet5*)0)->weight2_3));
    Delta(ctx, cache, ctx->deltaPipeline[2], GETCOUNT(((LeNet5*)0)->weight4_5));
    Delta(ctx, cache, ctx->deltaPipeline[3], GETCOUNT(((LeNet5*)0)->weight5_6));
    Update(ctx, cache);

    /* Make the updated weights visible to later host reads through a mapping
     * (e.g. SaveModel). vkQueueWaitIdle alone does not do this. */
    VkBufferMemoryBarrier toHost = { VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER };
    toHost.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT;
    toHost.dstAccessMask = VK_ACCESS_HOST_READ_BIT;
    toHost.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
    toHost.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
    toHost.buffer = ctx->lenet.Buffer;
    toHost.size = VK_WHOLE_SIZE;
    vkCmdPipelineBarrier(ctx->commandBuffer,
        VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_HOST_BIT, 0,
        0, NULL, 1, &toHost, 0, NULL);

    if (vkEndCommandBuffer(ctx->commandBuffer) != VK_SUCCESS)
        return;
    VkSubmitInfo submitInfo = { VK_STRUCTURE_TYPE_SUBMIT_INFO };
    submitInfo.commandBufferCount = 1;
    submitInfo.pCommandBuffers = &ctx->commandBuffer;
    if (vkQueueSubmit(ctx->queue, 1, &submitInfo, VK_NULL_HANDLE) == VK_SUCCESS)
        vkQueueWaitIdle(ctx->queue);
}
